# Inference by pymc
import rcsj_model_analysis
from rcsj_model_analysis import simulating_data
from rcsj_model_analysis import io
from rcsj_model_analysis import inference222

import numpy as np
from scipy import stats
from scipy import constants
import scipy.io as spio
from scipy.integrate import quad, simpson, cumtrapz
import pymc as pm
import pytensor.tensor as pt

# Physical constants in SI units, taken from scipy.constants.
h_bar, k_B, elementary_charge = (
    constants.hbar,               # reduced Planck constant, J s
    constants.Boltzmann,          # Boltzmann constant, J / K
    constants.elementary_charge,  # elementary charge, C
)

def import_package_test():
    '''
    Smoke test for the package import.

    Prints a fixed confirmation message to stdout and returns that same
    string, so the caller can check the result programmatically.
    '''
    message = "Successful import!"
    print(message)
    return message

def rcsj_model1(switching_values, upper_prior, f_j0, Temp):
    '''
    Build a PyMC model in which only the critical current is sampled.

    The likelihood is supplied by the custom Op ``inference222.LogLike``
    wrapping ``inference222.loglike``; it is attached to the model through
    a ``pm.Potential`` named ``"likelihood"``.

    Parameters
    ----------
    switching_values : array-like with a ``.max()`` method
        Observed switching values. Passed to ``inference222.LogLike`` and
        used to set the lower bound of the ``critical_current`` prior.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.
        (Earlier versions ignored this argument and used a hard-coded
        0.081; the argument is now honoured, as in the other
        ``rcsj_model*`` builders.)
    f_j0 : float
        Fixed value placed second in the parameter vector given to the
        likelihood Op.
    Temp : float
        Fixed value placed third in the parameter vector given to the
        likelihood Op (presumably the temperature -- confirm against
        ``inference222.loglike``).

    Returns
    -------
    pm.Model
        Model containing the free variable ``critical_current`` and the
        ``likelihood`` potential.
    '''
    logl = inference222.LogLike(inference222.loglike, switching_values)

    # The largest observed switching value is the lower bound of the prior.
    start_point = switching_values.max()

    model = pm.Model()

    with model:
        # Uniform prior on critical_current, bounded above by the caller.
        critical_current = pm.Uniform(
            "critical_current", lower=start_point, upper=upper_prior
        )

        # Pack the sampled and the fixed parameters into one tensor vector,
        # in the order the likelihood Op receives them.
        theta = pt.as_tensor_variable([critical_current, f_j0, Temp])

        # Use a Potential to "call" the Op and add it to the model's logp.
        pm.Potential("likelihood", logl(theta))

    return model


def rcsj_model2(switching_values, upper_prior):
    '''
    Build a PyMC model sampling the critical current, f_j0 and Temp.

    ``f_j0`` and ``Temp`` are sampled through their base-10 logarithms
    (``log_f_j0`` and ``log_Temp``, each uniform on [1, 9]) and converted
    back with ``10**x`` before being passed to the likelihood Op
    ``inference222.LogLike`` (wrapping ``inference222.loglike``).

    Parameters
    ----------
    switching_values : array-like with a ``.max()`` method
        Observed switching values. Passed to the likelihood Op and used as
        the lower bound of the ``critical_current`` prior.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.

    Returns
    -------
    pm.Model
        Model with free variables ``critical_current``, ``log_f_j0`` and
        ``log_Temp`` plus a ``likelihood`` potential.
    '''
    likelihood_op = inference222.LogLike(inference222.loglike, switching_values)

    # The largest observed switching value is the lower bound of the prior.
    ic_lower_bound = switching_values.max()

    with pm.Model() as model:
        # Uniform priors: critical current directly, the others in log10.
        critical_current = pm.Uniform(
            "critical_current", lower=ic_lower_bound, upper=upper_prior
        )
        log_f_j0 = pm.Uniform("log_f_j0", lower=1, upper=9)
        log_Temp = pm.Uniform("log_Temp", lower=1, upper=9)

        # Undo the log10 transform and pack everything into a tensor vector.
        params = pt.as_tensor_variable(
            [critical_current, 10**log_f_j0, 10**log_Temp]
        )

        # The Potential evaluates the Op and adds its value to the logp.
        pm.Potential("likelihood", likelihood_op(params))

    return model

def rcsj_model3(switching_values, upper_prior,Temp,ramp_speed):
    '''
    Build a PyMC model sampling the critical current, f_j0 and U_0.

    ``f_j0`` and ``U_0`` are sampled through their base-10 logarithms
    (``log_f_j0`` uniform on [8, 14], ``log_U_0`` uniform on [-2, 0]);
    ``Temp`` and ``ramp_speed`` are held fixed at the supplied values. The
    likelihood comes from ``inference222.LogLike_new`` wrapping
    ``inference222.loglike_new``.

    Parameters
    ----------
    switching_values : array-like with a ``.max()`` method
        Observed switching values. Passed to the likelihood Op and used as
        the lower bound of the ``critical_current`` prior.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.
    Temp : float
        Fixed value placed third in the parameter vector.
    ramp_speed : float
        Fixed value placed fifth in the parameter vector.

    Returns
    -------
    pm.Model
        Model with free variables ``critical_current``, ``log_f_j0`` and
        ``log_U_0`` plus a ``likelihood`` potential.
    '''
    likelihood_op = inference222.LogLike_new(
        inference222.loglike_new, switching_values
    )

    # The largest observed switching value is the lower bound of the prior.
    ic_lower_bound = switching_values.max()

    with pm.Model() as model:
        critical_current = pm.Uniform(
            "critical_current", lower=ic_lower_bound, upper=upper_prior
        )
        log_f_j0 = pm.Uniform("log_f_j0", lower=8, upper=14)
        log_U_0 = pm.Uniform("log_U_0", lower=-2, upper=0)

        # Undo the log10 transforms, then pack sampled and fixed values in
        # the order the likelihood Op receives them.
        params = pt.as_tensor_variable(
            [critical_current, 10**log_f_j0, Temp, 10**log_U_0, ramp_speed]
        )

        # The Potential evaluates the Op and adds its value to the logp.
        pm.Potential("likelihood", likelihood_op(params))

    return model

def rcsj_model4(switching_values, upper_prior,f_j0, Temp,ramp_speed):
    '''
    Build a PyMC model sampling the critical current and log10(U_0).

    ``U_0`` is sampled through ``log_U_0`` (uniform on [-2, 0]) and
    converted back with ``10**log_U_0``; ``f_j0``, ``Temp`` and
    ``ramp_speed`` are held fixed at the supplied values. The likelihood
    comes from ``inference222.LogLike_new`` wrapping
    ``inference222.loglike_new``.

    Parameters
    ----------
    switching_values : array-like with a ``.max()`` method
        Observed switching values. Passed to the likelihood Op and used as
        the lower bound of the ``critical_current`` prior.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.
    f_j0 : float
        Fixed value placed second in the parameter vector.
    Temp : float
        Fixed value placed third in the parameter vector.
    ramp_speed : float
        Fixed value placed fifth in the parameter vector.

    Returns
    -------
    pm.Model
        Model with free variables ``critical_current`` and ``log_U_0``
        plus a ``likelihood`` potential.
    '''
    likelihood_op = inference222.LogLike_new(
        inference222.loglike_new, switching_values
    )

    # The largest observed switching value is the lower bound of the prior.
    ic_lower_bound = switching_values.max()

    with pm.Model() as model:
        critical_current = pm.Uniform(
            "critical_current", lower=ic_lower_bound, upper=upper_prior
        )
        # log10(U_0) uniform on [-2, 0], i.e. U_0 between 0.01 and 1.
        log_U_0 = pm.Uniform("log_U_0", lower=-2, upper=0)

        # Undo the log10 transform, then pack sampled and fixed values in
        # the order the likelihood Op receives them.
        params = pt.as_tensor_variable(
            [critical_current, f_j0, Temp, 10**log_U_0, ramp_speed]
        )

        # The Potential evaluates the Op and adds its value to the logp.
        pm.Potential("likelihood", likelihood_op(params))

    return model

def rcsj_model5(switching_values, upper_prior,f_j0, Temp,ramp_speed):
    '''
    Build a PyMC model sampling the critical current and U_0 directly.

    Unlike the log-parameterised variant, ``U_0`` gets a uniform prior on
    its own scale ([0.01, 1]). ``f_j0``, ``Temp`` and ``ramp_speed`` are
    held fixed at the supplied values. The likelihood comes from
    ``inference222.LogLike_new`` wrapping ``inference222.loglike_new``.

    Parameters
    ----------
    switching_values : array-like with a ``.max()`` method
        Observed switching values. Passed to the likelihood Op and used as
        the lower bound of the ``critical_current`` prior.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.
    f_j0 : float
        Fixed value placed second in the parameter vector.
    Temp : float
        Fixed value placed third in the parameter vector.
    ramp_speed : float
        Fixed value placed fifth in the parameter vector.

    Returns
    -------
    pm.Model
        Model with free variables ``critical_current`` and ``U_0`` plus a
        ``likelihood`` potential.
    '''
    likelihood_op = inference222.LogLike_new(
        inference222.loglike_new, switching_values
    )

    # The largest observed switching value is the lower bound of the prior.
    ic_lower_bound = switching_values.max()

    with pm.Model() as model:
        critical_current = pm.Uniform(
            "critical_current", lower=ic_lower_bound, upper=upper_prior
        )
        U_0 = pm.Uniform("U_0", lower=0.01, upper=1)

        # Pack sampled and fixed values in the order the Op receives them.
        params = pt.as_tensor_variable(
            [critical_current, f_j0, Temp, U_0, ramp_speed]
        )

        # The Potential evaluates the Op and adds its value to the logp.
        pm.Potential("likelihood", likelihood_op(params))

    return model

def rcsj_model6(switching_values,lower_prior, upper_prior,f_j0,U_0, Temp,ramp_speed):
    '''
    Build a PyMC model in which only the critical current is sampled.

    Both bounds of the ``critical_current`` prior are supplied by the
    caller; ``f_j0``, ``U_0``, ``Temp`` and ``ramp_speed`` are all held
    fixed at the supplied values. The likelihood comes from
    ``inference222.LogLike_new`` wrapping ``inference222.loglike_new``.

    Parameters
    ----------
    switching_values : array-like
        Observed switching values, passed to the likelihood Op.
    lower_prior : float
        Lower bound of the uniform prior on ``critical_current``.
    upper_prior : float
        Upper bound of the uniform prior on ``critical_current``.
    f_j0 : float
        Fixed value placed second in the parameter vector.
    U_0 : float
        Fixed value placed fourth in the parameter vector.
    Temp : float
        Fixed value placed third in the parameter vector.
    ramp_speed : float
        Fixed value placed fifth in the parameter vector.

    Returns
    -------
    pm.Model
        Model with the free variable ``critical_current`` plus a
        ``likelihood`` potential.
    '''
    likelihood_op = inference222.LogLike_new(
        inference222.loglike_new, switching_values
    )

    with pm.Model() as model:
        critical_current = pm.Uniform(
            "critical_current", lower=lower_prior, upper=upper_prior
        )

        # Pack the sampled value and the fixed ones in the order the Op
        # receives them.
        params = pt.as_tensor_variable(
            [critical_current, f_j0, Temp, U_0, ramp_speed]
        )

        # The Potential evaluates the Op and adds its value to the logp.
        pm.Potential("likelihood", likelihood_op(params))

    return model